diff --git a/megatron/arguments.py b/megatron/arguments.py
index b35af1df..4c36694f 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@ import os
 
 import torch
 
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -42,6 +45,9 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
 
+    # FastMoE arguments.
+    parser = _add_fmoe_args(parser)
+
     # Custom arguments.
     if extra_args_provider is not None:
         parser = extra_args_provider(parser)
@@ -232,7 +238,11 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
-
+    # If fmoe_num_experts was not specified, we are running with an
+    # older version of Megatron; in that case copy num_experts into
+    # fmoe_num_experts.
+    if not hasattr(args, 'fmoe_num_experts'):
+        args.fmoe_num_experts = args.num_experts
     _print_args(args)
     return args
 
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 12510662..32afb2fa 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
+    6: np.float32,
     7: np.double,
     8: np.uint16
 }
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
-        np.float: 4,
+        np.float32: 4,
         np.double: 8
     }
 
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 823a51f4..32f4b2e1 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -69,8 +69,10 @@ def get_megatron_optimizer(model):
 
     # Determine whether the params have main-grad field.
     params_have_main_grad = False
-    if args.DDP_impl == 'local':
-        params_have_main_grad = True
+
+    # FastMoE does not have main_grad field
+    # if args.DDP_impl == 'local':
+    #     params_have_main_grad = True
 
     if args.fp16 or args.bf16:
 
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 036a1d4c..81d5bd96 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,17 +54,23 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     #   - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    # FastMoE
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)
         is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        grad = param.grad.detach()
         if grad_not_none:
+            grad = param.grad.detach()
             # Make sure the grads are in fp32
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            # FastMoE: grads of params tagged dp_comm == 'none' are collected apart and normed separately below.
+            if getattr(param, 'dp_comm', None) == 'none':
+                grads_in_moe.append(grad)
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -73,6 +79,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
+        # FastMoE TODO
+        assert False, f"norm_type {norm_type} is not supported by FastMoE "
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -97,7 +105,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
 
+            # FastMoE
+            if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+                grad_norm, _ = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads_in_moe],
+                    False # no per-parameter norm
+                )
+                grad_norm = grad_norm ** norm_type
+                torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+                total_norm += grad_norm
         else:
+            # FastMoE TODO
+            assert False, f"norm_type {norm_type} is not supported by FastMoE "
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 368f5875..080b06f0 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -250,6 +250,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                                                                   param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
+                        # FastMoE
+                        if hasattr(param, 'dp_comm'):
+                            main_param.dp_comm = param.dp_comm
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
                         fp32_from_float16_params_this_group.append(main_param)
@@ -396,17 +399,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             # so we can update the loss scale.
             self.grad_scaler.update(found_inf_flag)
 
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
+            # The inf/nan early return now happens after gradient clipping (see below).
+            # if found_inf_flag:
+            #     return False, None, None
 
         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
+        
+        # Clip unconditionally (the clip_grad > 0.0 check is removed) to avoid a deadlock in FastMoE.
+        # if self.clip_grad > 0.0:
+        #     grad_norm = self.clip_grad_norm(self.clip_grad)
+        grad_norm = self.clip_grad_norm(self.clip_grad)
+        
         timers('optimizer-clip-main-grad').stop()
 
+        # If we found inf/nan, skip the update. The early return was moved here,
+        # after clipping, to avoid a FastMoE deadlock. NOTE(review): confirm found_inf_flag is always bound here.
+        if found_inf_flag:
+            return False, None, None
+
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
diff --git a/megatron/schedules.py b/megatron/schedules.py
index d346c30d..8eef46c8 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -23,7 +23,11 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import p2p_communication
 from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
 from megatron.model import Float16Module
 
 def get_forward_backward_func():
@@ -54,7 +58,8 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(data_iterator, model)
+    output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+    bal_loss = bal_loss / get_num_microbatches()
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
@@ -62,10 +67,10 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
         losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()
 
-    return output_tensor
+    return output_tensor, bal_loss
 
 
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
     """Backward step through passed-in output tensor.
 
     If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -85,7 +90,9 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad is None:
         output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    else:
+        torch.autograd.backward([output_tensor,bal_loss], grad_tensors=[output_tensor_grad, None])
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = None
@@ -122,18 +129,18 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
+            output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                          input_tensor, losses_reduced)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
+    output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                  input_tensor, losses_reduced)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
 
     return losses_reduced
 
@@ -144,6 +151,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
+    # FastMoE TODO
+    assert False, "FastMoE not supports pipeline with interleaving"
+
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     losses_reduced = []
@@ -385,17 +395,19 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
 
     input_tensors = []
     output_tensors = []
+    bal_losses = []
     losses_reduced = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = p2p_communication.recv_forward(timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         p2p_communication.send_forward(output_tensor, timers=timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
+        bal_losses.append(bal_loss)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -407,7 +419,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     for i in range(num_microbatches_remaining):
         last_iteration = (i == (num_microbatches_remaining - 1))
 
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
             p2p_communication.send_forward(output_tensor, timers=timers)
@@ -420,16 +432,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         # start of the list for backward pass.
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
+        bal_losses.append(bal_loss)
 
         if forward_only:
             if not last_iteration:
                 input_tensor = p2p_communication.recv_forward(timers=timers)
         else:
-            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+            input_tensor, output_tensor, bal_loss = input_tensors.pop(0), output_tensors.pop(0), bal_losses.pop(0)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
             if last_iteration:
                 input_tensor = None
@@ -444,12 +457,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         for i in range(num_warmup_microbatches):
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
+            bal_loss = bal_losses.pop(0)
 
             output_tensor_grad = p2p_communication.recv_backward(timers=timers)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
             p2p_communication.send_backward(input_tensor_grad, timers=timers)
 
diff --git a/megatron/training.py b/megatron/training.py
index 1ab57e9c..fbe2fe8e 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,14 +35,23 @@ from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
 from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -107,6 +116,13 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
+    # Initialize FastMoE
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v2.5")
+        model_provider = patch_model_provider(model_provider, Megatron_Version='v2.5')
+
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -386,10 +402,12 @@ def train_step(forward_step_func, data_iterator,
 
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
+            grad = word_embeddings_weight.grad
+            # FastMoE does not have main_grad field
+            # if args.DDP_impl == 'local':
+            #     grad = word_embeddings_weight.main_grad
+            # else:
+            #     grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()
 
@@ -458,26 +476,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     # Logging.
     timers_to_log = []
 
-    def add_to_logging(name):
-        if name in timers.timers:
+    # FastMoE adds several timers of its own.
+    # For simplicity, log every timer present in timers.timers.
+    def add_all():
+        for name in timers.timers:
             timers_to_log.append(name)
-    add_to_logging('forward-compute')
-    add_to_logging('forward-recv')
-    add_to_logging('forward-send')
-    add_to_logging('forward-backward-send-forward-backward-recv')
-    add_to_logging('backward-compute')
-    add_to_logging('backward-recv')
-    add_to_logging('backward-send')
-    add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-send-backward-recv')
-    add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('optimizer-copy-to-main-grad')
-    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-main-grad')
-    add_to_logging('optimizer-copy-main-to-model-params')
-    add_to_logging('optimizer')
-    add_to_logging('batch-generator')
+
+    add_all()
 
     # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \

